-
Notifications
You must be signed in to change notification settings - Fork 229
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
compiler: Fix complex arguments and implement float16 lowering #2403
base: complex
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some quick comments
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
59f59b8
to
493c1e8
Compare
devito/passes/iet/dtypes.py
Outdated
params_mapper = {} | ||
|
||
# Lower scalar float16s to pointers and dereference them | ||
for s in FindSymbols('scalars').visit(iet): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you just do visit(iet.parameters)
directly instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like not, since the actual types I'm mapping (e.g. Constant
) aren't nodes they are only caught by FindSymbols
if it's as a reference within some other expression, so we need to visit the body. That said I can probably make this marginally more efficient by making a set of parameters beforehand and checking membership that way
@classmethod | ||
def _load_dtype_mappings(cls, **kwargs): | ||
lang: type[LangBB] = cls._Target.DataManager.lang | ||
ctypes_vector_mapper.update(lang.mapper.get('types', {})) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that's a bit tricky because this updates a global ctypes_vector_mapper
which might lead to odd behavior building multiple operators with different languages.
Do you know where it's called and needs those types
? I.e can the mapper be "local" to the operator and passed there?
@@ -15,6 +18,62 @@ | |||
} | |||
|
|||
|
|||
class NoDeclStruct(ctypes.Structure): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ZoeLeibowitz I may be wrong but don't you need this type too? please see also where it's used and comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have different use cases because I need a local composite type which still generates a struct in the header file (i.e still produces a struct definition)
devito/passes/iet/dtypes.py
Outdated
|
||
|
||
def lower_complex(iet, lang, compiler): | ||
def lower_dtypes(iet, lang, compiler, sregistry): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how can this not be an @iet_pass
and still work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's called from an @iet_pass
in definitions
here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah OK , I maybe have told Mathias already, but imho that thing goes straight into operator/operator.py::Operator::_lower_iet
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah okay, I can make this change
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Few comments, looking good, but still needs tests
seconded, this is crucial |
Yeah my plan was to write some more extensive tests for complex (since the existing ones didn't catch all of the break cases I've found) and add in a few for half. I want to get all of these changes out of the way first, and then there's a few more things to work out such as math functions being assigned to float16 arrays. Maybe this should be a draft actually? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tests looking good. I think there should also be a test to ensure that the use of fp16/complex gets propagated into the variables introduced by CSE
@@ -644,16 +644,18 @@ def test_tensor(self, func1): | |||
def test_complex(self, dtype): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The introduction of an override is a good idea here.
I think this wants separating out into a couple of tests:
- Test that the correct expressions are generated
- Test that you can specify derivatives of complex fields
- Test that you can override complex values
- Test that you can mix complex and real symbols on the RHS
Note: do we also need to test that halo exchanges work with complex?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, will add some tests for those points.
Re: halo exchanges, I'll get back to you when I figure out what a halo exchange is
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A halo exchange is where each MPI rank swaps the data along its edges with its neighbours before advancing to the next timestep. It would probably worth checking that the MPI communication doesn't garble the data somehow when moving it between ranks. You would probably need some equations like:
from devito import Grid, TimeFunction, Eq, Operator
grid = Grid(shape=(4, 4))
x, y = grid.dimensions
t = grid.stepping_dim
u = TimeFunction(name='u', grid=grid)
v = TimeFunction(name='v', grid=grid)
eqns = [Eq(u[t+1, x, y], 1+sympy.I), Eq(v[t+1, x, y], sympy.I*u[t, x, y+1])]
op = Operator(eqns)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. I think what I'll do is leave these general complex arithmetic tests where they are (in test_operator
and test_gpu_common
) and add these lower-level ones to the test_dtypes
module. Except this halo exchange one goes in test_mpi
I guess
Fixes some issues that prevented complex float arguments from being passed correctly. Also implements lowering for half-precision floats with the
_Float16
type.Some remaining tasks:
ctypes_vector_mappet
; need some more tests for operator building/rebuilding/pickling with float16, or potentially a different (hopefully less brittle) solution